{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "42103cae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://huggingface.co/datasets/mesolitica/semisupervised-corpus/resolve/main/similarity/shuffled-train.json\n",
    "# !wget https://huggingface.co/datasets/mesolitica/semisupervised-corpus/resolve/main/similarity/shuffled-test.json"
   ]
  },
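  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1f3c9d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick sanity check on the dataset format (a sketch, assuming the files above\n",
    "# have been downloaded): each line is a standalone JSON object, and the cells\n",
    "# below rely on its 'src' and 'label' keys.\n",
    "import json\n",
    "\n",
    "with open('shuffled-train.json') as fopen:\n",
    "    sample = json.loads(next(fopen))\n",
    "print(sample.keys())"
   ]
  },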
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f21abd15",
   "metadata": {},
   "outputs": [],
   "source": [
    "from bidirectional_mistral import MistralBiModel\n",
    "from transformers import MistralPreTrainedModel\n",
    "import torch\n",
    "import numpy as np\n",
    "import json\n",
    "from typing import Optional, List\n",
    "from torch import nn\n",
    "from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss\n",
    "from transformers.modeling_outputs import SequenceClassifierOutputWithPast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bd862a54",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MistralForSequenceClassification(MistralPreTrainedModel):\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "        self.num_labels = config.num_labels\n",
    "        self.model = MistralBiModel(config)\n",
    "        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)\n",
    "\n",
    "        # Initialize weights and apply final processing\n",
    "        self.post_init()\n",
    "        \n",
    "    def forward(\n",
    "        self,\n",
    "        input_ids: torch.LongTensor = None,\n",
    "        attention_mask: Optional[torch.Tensor] = None,\n",
    "        position_ids: Optional[torch.LongTensor] = None,\n",
    "        past_key_values: Optional[List[torch.FloatTensor]] = None,\n",
    "        inputs_embeds: Optional[torch.FloatTensor] = None,\n",
    "        labels: Optional[torch.LongTensor] = None,\n",
    "        use_cache: Optional[bool] = None,\n",
    "        output_attentions: Optional[bool] = None,\n",
    "        output_hidden_states: Optional[bool] = None,\n",
    "        return_dict: Optional[bool] = None,\n",
    "    ):\n",
    "        r\"\"\"\n",
    "        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\n",
    "            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n",
    "            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n",
    "            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n",
    "        \"\"\"\n",
    "        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
    "\n",
    "        transformer_outputs = self.model(\n",
    "            input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            position_ids=position_ids,\n",
    "            past_key_values=past_key_values,\n",
    "            inputs_embeds=inputs_embeds,\n",
    "            use_cache=use_cache,\n",
    "            output_attentions=output_attentions,\n",
    "            output_hidden_states=output_hidden_states,\n",
    "            return_dict=return_dict,\n",
    "        )\n",
    "        pooled_output = transformer_outputs[0][:, 0]\n",
    "        logits = self.score(pooled_output)\n",
    "\n",
    "        loss = None\n",
    "        if labels is not None:\n",
    "            if self.config.problem_type is None:\n",
    "                if self.num_labels == 1:\n",
    "                    self.config.problem_type = \"regression\"\n",
    "                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):\n",
    "                    self.config.problem_type = \"single_label_classification\"\n",
    "                else:\n",
    "                    self.config.problem_type = \"multi_label_classification\"\n",
    "\n",
    "            if self.config.problem_type == \"regression\":\n",
    "                loss_fct = MSELoss()\n",
    "                if self.num_labels == 1:\n",
    "                    loss = loss_fct(logits.squeeze(), labels.squeeze())\n",
    "                else:\n",
    "                    loss = loss_fct(logits, labels)\n",
    "            elif self.config.problem_type == \"single_label_classification\":\n",
    "                loss_fct = CrossEntropyLoss()\n",
    "                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))\n",
    "            elif self.config.problem_type == \"multi_label_classification\":\n",
    "                loss_fct = BCEWithLogitsLoss()\n",
    "                loss = loss_fct(logits, labels)\n",
    "        if not return_dict:\n",
    "            output = (logits,) + transformer_outputs[2:]\n",
    "            return ((loss,) + transformer_outputs) if loss is not None else output\n",
    "        \n",
    "        return SequenceClassifierOutputWithPast(\n",
    "            loss=loss,\n",
    "            logits=logits,\n",
    "            past_key_values=transformer_outputs.past_key_values,\n",
    "            hidden_states=transformer_outputs.hidden_states,\n",
    "            attentions=transformer_outputs.attentions,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3b30b8a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "57905bf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained('mesolitica/malaysian-mistral-64M-MLM-512')\n",
    "# tokenizer.padding_side = 'left'\n",
    "config = AutoConfig.from_pretrained('mesolitica/malaysian-mistral-64M-MLM-512')\n",
    "config.problem_type = \"single_label_classification\"\n",
    "config.label2id = {'contradiction': 0, 'entailment': 1}"
   ]
  },
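  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e402aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the tensors the tokenizer produces (a sketch; the input strings are\n",
    "# hypothetical Malay examples). The training loop below feeds exactly these\n",
    "# 'input_ids' and 'attention_mask' tensors to the model.\n",
    "out = tokenizer(['contoh ayat pertama', 'ayat kedua'], padding = True, return_tensors = 'pt')\n",
    "print(out['input_ids'].shape, out['attention_mask'].shape)"
   ]
  },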
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9398590c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation=\"flash_attention_2\"` instead.\n",
      "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour\n",
      "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.\n",
      "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in MistralForSequenceClassification is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n",
      "Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in MistralBiModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained(\"openai/whisper-tiny\", attn_implementation=\"flash_attention_2\", torch_dtype=torch.float16)`\n",
      "Some weights of MistralForSequenceClassification were not initialized from the model checkpoint at mesolitica/malaysian-mistral-64M-MLM-512 and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "model = MistralForSequenceClassification.from_pretrained(\n",
    "    'mesolitica/malaysian-mistral-64M-MLM-512', \n",
    "    config = config,\n",
    "    use_flash_attention_2 = True,\n",
    ")"
   ]
  },
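  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3d91e55",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check that the loaded model is in the ~64M-parameter range implied by\n",
    "# the checkpoint name (a sketch; the exact count depends on the config).\n",
    "total = sum(p.numel() for p in model.parameters())\n",
    "print(f'{total / 1e6:.1f}M parameters')"
   ]
  },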
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "46ec48b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainable_parameters = [param for param in model.parameters() if param.requires_grad]\n",
    "trainer = torch.optim.AdamW(trainable_parameters, lr = 2e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "db9478b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_X, train_Y = [], []\n",
    "with open('shuffled-train.json') as fopen:\n",
    "    for l in fopen:\n",
    "        l = json.loads(l)\n",
    "        train_X.append(l['src'])\n",
    "        train_Y.append(l['label'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ce57e047",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "16037"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_X, test_Y = [], []\n",
    "with open('shuffled-test.json') as fopen:\n",
    "    for l in fopen:\n",
    "        l = json.loads(l)\n",
    "        test_X.append(l['src'])\n",
    "        test_Y.append(l['label'])\n",
    "        \n",
    "len(test_X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "67a4fcba",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 16\n",
    "epoch = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5eab2ca4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from accelerate import Accelerator\n",
    "accelerator = Accelerator(mixed_precision = 'bf16')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3a1b3101",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = accelerator.device\n",
    "model, trainer = accelerator.prepare(model, trainer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "3e5bf4f3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████| 68348/68348 [50:43<00:00, 22.46it/s, loss=0.352]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, loss: 0.4492138416358152, dev_predicted: 0.8014\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████| 68348/68348 [50:41<00:00, 22.47it/s, loss=0.215]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, loss: 0.3653065696207812, dev_predicted: 0.807\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████| 68348/68348 [50:42<00:00, 22.46it/s, loss=0.367]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 2, loss: 0.3176246317962609, dev_predicted: 0.8152\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████| 68348/68348 [50:45<00:00, 22.44it/s, loss=0.148]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 3, loss: 0.2712462457894229, dev_predicted: 0.8085\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "best_dev_acc = -np.inf\n",
    "patient = 1\n",
    "current_patient = 0\n",
    "\n",
    "for e in range(epoch):\n",
    "    pbar = tqdm(range(0, len(train_X), batch_size))\n",
    "    losses = []\n",
    "    for i in pbar:\n",
    "        trainer.zero_grad()\n",
    "        x = train_X[i: i + batch_size]\n",
    "        y = np.array(train_Y[i: i + batch_size])\n",
    "        \n",
    "        padded = tokenizer(x, truncation = True, padding = True, return_tensors = 'pt', max_length = 4096)\n",
    "        padded['labels'] = torch.from_numpy(y)\n",
    "        for k in padded.keys():\n",
    "            padded[k] = padded[k].cuda()\n",
    "\n",
    "        padded.pop('token_type_ids', None)\n",
    "\n",
    "        o = model(**padded, use_cache = False)\n",
    "        loss = o.loss\n",
    "        \n",
    "        accelerator.backward(loss)\n",
    "\n",
    "        grad_norm = torch.nn.utils.clip_grad_norm_(trainable_parameters, 1.0)\n",
    "        trainer.step()\n",
    "            \n",
    "        losses.append(float(loss))\n",
    "        pbar.set_postfix(loss = float(loss))\n",
    "        \n",
    "        \n",
    "    dev_predicted = []\n",
    "    for i in range(0, len(test_X[:10000]), batch_size):\n",
    "        x = test_X[i: i + batch_size]\n",
    "        y = np.array(test_Y[i: i + batch_size])\n",
    "        padded = tokenizer(x, truncation = True, padding = True, return_tensors = 'pt', max_length = 4096)\n",
    "        padded['labels'] = torch.from_numpy(y)\n",
    "        for k in padded.keys():\n",
    "            padded[k] = padded[k].cuda()\n",
    "            \n",
    "        padded.pop('token_type_ids', None)\n",
    "        \n",
    "        o = model(**padded, use_cache = False)\n",
    "        loss = o.loss\n",
    "        pred = o.logits\n",
    "        \n",
    "        dev_predicted.append((pred.argmax(axis = 1).detach().cpu().numpy() == y).mean())\n",
    "        \n",
    "    dev_predicted = np.mean(dev_predicted)\n",
    "    \n",
    "    print(f'epoch: {e}, loss: {np.mean(losses)}, dev_predicted: {dev_predicted}')\n",
    "    \n",
    "    if dev_predicted >= best_dev_acc:\n",
    "        best_dev_acc = dev_predicted\n",
    "        current_patient = 0\n",
    "        model.save_pretrained('64M')\n",
    "    else:\n",
    "        current_patient += 1\n",
    "    \n",
    "    if current_patient >= patient:\n",
    "        break"
   ]
  },
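  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8f12b04",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal inference sketch on a couple of held-out examples, reusing the\n",
    "# in-memory model (assumes the training cell above has run). Softmax over the\n",
    "# two logits gives class probabilities; the mapping comes from config.label2id.\n",
    "import torch.nn.functional as F\n",
    "\n",
    "id2label = {v: k for k, v in config.label2id.items()}\n",
    "\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    padded = tokenizer(test_X[:2], truncation = True, padding = True, return_tensors = 'pt')\n",
    "    padded.pop('token_type_ids', None)\n",
    "    for k in padded.keys():\n",
    "        padded[k] = padded[k].to(device)\n",
    "    probs = F.softmax(model(**padded, use_cache = False).logits, dim = -1)\n",
    "\n",
    "for row in probs:\n",
    "    print(id2label[int(row.argmax())], row.float().cpu().numpy())"
   ]
  },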
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efe9c2e0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
